Context¶
Recent studies have shown that breast cancer remains a leading cause of cancer death among women worldwide. If detected at an early stage, it can be cured in 9 out of 10 cases.
Automated detection and segmentation of cells from images are crucial, fundamental steps in measuring cellular morphology, which in turn is central to breast cancer diagnosis and prognosis.
In this notebook, you will learn how to train a segmentation model (a U-Net) with MONAI, a PyTorch-based framework for healthcare imaging.
! pip install monai
Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: numpy<2.0,>=1.24 in /usr/local/lib/python3.10/dist-packages (from monai) (1.26.4)
Requirement already satisfied: torch>=1.9 in /usr/local/lib/python3.10/dist-packages (from monai) (2.4.1+cu121)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (3.16.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (4.12.2)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (1.13.3)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (3.3)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (3.1.4)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (2024.6.1)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.9->monai) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.9->monai) (1.3.0)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
Installing collected packages: monai
Successfully installed monai-1.4.0
import os
import re
import torch
from torch.optim import Adam
from torch.utils.data import random_split, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, ToTensor, NormalizeIntensity, ScaleIntensity
from monai.data import PILReader
from monai.networks.nets import UNet, AttentionUnet
from monai.losses import DiceLoss
Monai¶
MONAI is a PyTorch-based open-source AI framework launched by NVIDIA and King’s College London. It integrates training and modelling workflows in a native PyTorch style.
Install monai
! pip install monai
Check the installation by running the following cell
import monai
monai.config.print_config()
MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.4.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py
Optional dependencies:
Pytorch Ignite version: 0.5.1
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.24.0
scipy version: 1.13.1
Pillow version: 10.4.0
Tensorboard version: 2.17.0
gdown version: 5.2.0
TorchVision version: 0.19.1+cu121
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.1.4
einops version: 0.8.0
transformers version: 4.44.2
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
Dataset¶
To train a model, we need to prepare some ingredients:
- Dataset
- Model
- Loss function
- Optimizer
I. Create Dataset¶
There are two ways to create your dataset:
- with torch.utils.data.Dataset
- with monai.data.Dataset.
In this exercise, we will create our dataset using torch.utils.data.Dataset.
1. List all files in folder¶
Download the dataset from https://zenodo.org/record/1175282#.YMn_Qy-FDox
Notice that there are two kinds of folders: the original cell picture folders (Slide_XX) and the mask folders (GT_XX). Using your file explorer or some code, display one image and its corresponding mask.
mask = Image.open(os.path.join("/kaggle/input/tbnc-nuclei-segmentation/TNBC_NucleiSegmentation/GT_01/", "01_1.png"))
original = Image.open(os.path.join("/kaggle/input/tbnc-nuclei-segmentation/TNBC_NucleiSegmentation/Slide_01/", "01_1.png"))
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(original, cmap='gray')
plt.title("Original Cell Picture")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(mask, cmap='gray')
plt.title("Mask")
plt.axis('off')
plt.tight_layout()
plt.show()
2. Define a transform¶
When you load your data, you need to define some transformations. For example, we want to convert each image to the channel-first format [num_channels, spatial_dim_1, spatial_dim_2], because MONAI/PyTorch expect this layout. We'll also need to convert the images to PyTorch tensors with ToTensor().
The following code loads the images and the labels and defines several steps to transform the data.
image_trans = Compose(
    [
        LoadImage(image_only=True, reader=PILReader(converter=lambda image: image.convert("RGB"))),
        EnsureChannelFirst(),
        NormalizeIntensity(),
        ToTensor(),
    ])
label_trans = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        ScaleIntensity(),
        ToTensor(),
    ])
3. Create dataset¶
The following class CellDataset allows us to create our dataset from "image_files" and "label_files" where:
- "image_files" is a list of image file paths
- "label_files" is the corresponding list of segmentation mask paths.
"im_trans" and "label_trans" are respectively the transforms for the images and their labels.
class CellDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, label_files, im_trans, label_trans):
        self.image_files = image_files
        self.label_files = label_files
        self.im_trans = im_trans
        self.label_trans = label_trans

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        # note: label_files must be accessed through self
        return self.im_trans(self.image_files[index]), self.label_trans(self.label_files[index])
By using this class, create your training dataset and your test dataset. Remember to check that your dataset is loaded correctly.
image_files = []
label_files = []
for root, dirs, files in os.walk("/kaggle/input/tbnc-nuclei-segmentation/TNBC_NucleiSegmentation", topdown=False):
    current_folder = root.split("/")[-1]
    for name in files:
        filepath = os.path.join(root, name)
        if current_folder.startswith("GT"):
            label_files.append(filepath)
        else:
            image_files.append(filepath)
# sort by slide number, then by filename, so each image stays aligned with its mask
image_files = sorted(image_files, key=lambda f: (int(re.search(r'_(\d+)', f).group(1)), os.path.basename(f)))
label_files = sorted(label_files, key=lambda f: (int(re.search(r'_(\d+)', f).group(1)), os.path.basename(f)))
dataset = CellDataset(image_files=image_files, label_files=label_files, im_trans=image_trans, label_trans=label_trans)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
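Note that random_split partitions the data randomly, so the train/test split changes between runs. If reproducibility matters, you can pass a seeded generator; a small sketch on a dummy dataset:

```python
import torch
from torch.utils.data import TensorDataset, random_split

dummy = TensorDataset(torch.arange(50).float())
train_size = int(0.8 * len(dummy))
test_size = len(dummy) - train_size

# the seeded generator makes the partition deterministic across runs
g = torch.Generator().manual_seed(42)
train_ds, test_ds = random_split(dummy, [train_size, test_size], generator=g)
print(len(train_ds), len(test_ds))  # 40 10
```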
4. DataLoader¶
With your dataset loaded, you have to pass it to a DataLoader. The torch.utils.data.DataLoader takes a dataset and returns batches of images and the corresponding labels. You can set various parameters, such as the batch size and whether the data is shuffled after each epoch.
The following code creates a data loader for the train dataset; do the same to create a loader for the test dataset. Name it test_loader.
def plot_image_and_label(image, label, title_image="Image", title_label="Label"):
    image = image.permute(1, 2, 0).cpu().numpy() if isinstance(image, torch.Tensor) else image
    label = label.squeeze(0).cpu().numpy() if isinstance(label, torch.Tensor) else label
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[0].set_title(title_image)
    ax[0].axis('off')
    ax[1].imshow(label)
    ax[1].set_title(title_label)
    ax[1].axis('off')
    plt.show()
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)
for idx, (image, label) in enumerate(train_loader):
    if idx >= 5:
        break
    plot_image_and_label(image[0], label[0])
5. Now, time to check your dataloader.¶
Execute the following code to check that your dataloader works correctly.
import monai
im, seg = monai.utils.misc.first(train_loader)
im.shape, seg.shape
(torch.Size([4, 3, 512, 512]), torch.Size([4, 1, 512, 512]))
II. Build your segmentation model with monai¶
MONAI already provides a UNet architecture: https://docs.monai.io/en/stable/networks.html#unet
Using the monai.networks.nets module, build a UNet model for a 2D segmentation task. You'll have to choose the following parameters for the model:
- spatial_dims (number of spatial dimensions)
- in_channels (number of input channels)
- out_channels (number of output channels)
- channels
- strides
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(spatial_dims=2, in_channels=3, out_channels=1, channels=(64, 128, 256, 512), strides=(2, 2, 2)).to(device)
III. Define your loss function and optimizer¶
For a segmentation problem, we usually use the Dice loss. Using monai.losses.DiceLoss, define your loss function and store it in a variable named loss_function. The option sigmoid=True should be used.
loss_function = DiceLoss(sigmoid=True)
With torch.optim, define an optimizer for your model. Use the Adam optimizer.
optimizer = Adam(model.parameters(), lr=0.01)
IV. Training the model¶
This time, we have all the ingredients to train a segmentation model: a model, an optimizer, a train_loader and a loss function.
MONAI uses the standard PyTorch program style for training a deep learning model.
The general process for one training step is as follows:
- Load the input and label of each batch
- Zero the accumulated gradients with optimizer.zero_grad()
- Compute the output from the model
- Calculate the loss
- Perform backpropagation with loss.backward()
- Update the model parameters with optimizer.step()
Complete the following code so that it performs the training.
epoch_loss_values = list()
for epoch in range(2):
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        # compute the model predictions from the inputs
        predictions = model(inputs)
        # compute the loss from the predictions and labels
        loss = loss_function(predictions, labels)
        # backpropagate to compute the gradients of the loss
        loss.backward()
        # update the model parameters
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_dataset) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
1/10, train_loss: 0.8703
2/10, train_loss: 0.5666
3/10, train_loss: 0.5554
4/10, train_loss: 0.6842
5/10, train_loss: 0.4695
6/10, train_loss: 0.3653
7/10, train_loss: 0.5534
8/10, train_loss: 0.4773
9/10, train_loss: 0.6555
10/10, train_loss: 0.4701
epoch 1 average loss: 0.5667
1/10, train_loss: 0.4104
2/10, train_loss: 0.6510
3/10, train_loss: 0.3306
4/10, train_loss: 0.4400
5/10, train_loss: 0.4354
6/10, train_loss: 0.4072
7/10, train_loss: 0.4231
8/10, train_loss: 0.3322
9/10, train_loss: 0.4652
10/10, train_loss: 0.4605
epoch 2 average loss: 0.4356
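The per-epoch averages collected in epoch_loss_values can be plotted to visualise convergence. A sketch, hard-coding the two averages printed above (in the notebook, you would pass epoch_loss_values directly):

```python
import matplotlib.pyplot as plt

# averages from the run above; in the notebook, use epoch_loss_values directly
epoch_loss_values = [0.5667, 0.4356]

plt.plot(range(1, len(epoch_loss_values) + 1), epoch_loss_values, marker="o")
plt.xlabel("epoch")
plt.ylabel("average Dice loss")
plt.show()
```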
Display the predictions of your model on several images.
model.eval()
for batch_data in test_loader:
    images, labels = batch_data[0].to(device), batch_data[1].to(device)
    with torch.no_grad():
        predictions = model(images)
    # the model outputs raw logits; map them through a sigmoid before display
    predictions = torch.sigmoid(predictions).squeeze(1)
    for idx in range(len(images)):
        if idx >= 5:
            break
        plot_image_and_label(images[idx], predictions[idx], title_image="Image", title_label="Prediction")
Train another architecture (either another U-Net variant or another segmentation model from MONAI's available models) and compare the results with the first model.
model = AttentionUnet(spatial_dims=2, in_channels=3, out_channels=1, channels=(64, 128, 256, 512), strides=(2, 2, 2)).to(device)
# re-create the optimizer so it updates the new model's parameters;
# otherwise it would keep optimizing the previous UNet
optimizer = Adam(model.parameters(), lr=0.01)
epoch_loss_values = list()
for epoch in range(2):
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        # compute the model predictions from the inputs
        predictions = model(inputs)
        # compute the loss from the predictions and labels
        loss = loss_function(predictions, labels)
        # backpropagate to compute the gradients of the loss
        loss.backward()
        # update the model parameters
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_dataset) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
1/10, train_loss: 0.8724
2/10, train_loss: 0.7812
3/10, train_loss: 0.7623
4/10, train_loss: 0.8248
5/10, train_loss: 0.7452
6/10, train_loss: 0.7962
7/10, train_loss: 0.8220
8/10, train_loss: 0.8588
9/10, train_loss: 0.7705
10/10, train_loss: 0.7322
epoch 1 average loss: 0.7966
1/10, train_loss: 0.6889
2/10, train_loss: 0.8230
3/10, train_loss: 0.8052
4/10, train_loss: 0.7032
5/10, train_loss: 0.8616
6/10, train_loss: 0.7713
7/10, train_loss: 0.8579
8/10, train_loss: 0.8752
9/10, train_loss: 0.8074
10/10, train_loss: 0.7750
epoch 2 average loss: 0.7969
model.eval()
for batch_data in test_loader:
    images, labels = batch_data[0].to(device), batch_data[1].to(device)
    with torch.no_grad():
        predictions = model(images)
    # the model outputs raw logits; map them through a sigmoid before display
    predictions = torch.sigmoid(predictions).squeeze(1)
    for idx in range(len(images)):
        if idx >= 5:
            break
        plot_image_and_label(images[idx], predictions[idx], title_image="Image", title_label="Prediction")